{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Automated Error Analysis\n",
    "\n",
    "This notebook uses LLMs to analyze and summarize the errors in the eval results across a single common dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "pd.set_option('display.max_colwidth', None)\n",
    "pd.set_option('display.max_rows', None)\n",
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "from openai import OpenAI\n",
    "import os\n",
    "import json\n",
    "import warnings\n",
    "from eval.eval import get_all_minimal_queries\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "openai = OpenAI(api_key=os.environ.get(\"OPENAI_API_KEY\"))\n",
    "model_4_latest = \"gpt-4-turbo-2024-04-09\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. Extract all results for a single dataset name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results_dir_path = f\"results/\"\n",
    "ds_to_analyze = \"006b\" # Provide the dataset name e.g. 001, 002 etc\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def list_csv_files(path):\n",
    "    # Get all file names in the directory\n",
    "    file_names = os.listdir(path)\n",
    "\n",
    "    # Filter for only .csv files\n",
    "    csv_files = [f for f in file_names if f.endswith('.csv')]\n",
    "\n",
    "    # Sort files alphabetically\n",
    "    csv_files.sort()\n",
    "\n",
    "    # Create dictionary with numbered keys\n",
    "    csv_dict = {i: csv_files[i] for i in range(len(csv_files))}\n",
    "\n",
    "    return csv_dict\n",
    "\n",
    "csv_files = list_csv_files(results_dir_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get indices of csv_files with these keywords\n",
    "keywords = [ds_to_analyze]\n",
    "keywords_to_exclude = []\n",
    "selected_models = [i for i, s in enumerate(csv_files.values()) if all(xs in s for xs in keywords)]\n",
    "if keywords_to_exclude:\n",
    "    selected_models = [i for i in selected_models if not any(xs in csv_files[i] for xs in keywords_to_exclude)]\n",
    "\n",
    "# Print selected models\n",
    "print(\"Results to analyze:\")\n",
    "for i in selected_models:\n",
    "    print(f\"{csv_files[i]}\")\n",
    "\n",
    "# Load results from csv file into dataframe\n",
    "dfs = {}\n",
    "for id in selected_models:\n",
    "    file_name = csv_files[id]\n",
    "    model = file_name.replace('.csv', '')\n",
    "    dfs[model] = pd.read_csv(results_dir_path + file_name, comment='#')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Combine all dataframes of selected models into one\n",
    "all_dfs = []\n",
    "for model in dfs:\n",
    "    temp_df = dfs[model]\n",
    "    temp_df['model'] = model\n",
    "    all_dfs.append(temp_df)\n",
    "df = pd.concat(all_dfs)\n",
    "# Apply get_all_minimal_queries to all queries\n",
    "df['true_queries'] = df['query'].apply(get_all_minimal_queries)\n",
    "df['error_msg_short'] = df['error_msg'].str.split(\"\\n\\n\").str[0].str.replace(\"QUERY EXECUTION ERROR:\", \"\")\n",
    "\n",
    "# Split model column by the last underscore \n",
    "df['eval'] = df['model'].str.rsplit(pat='_', n=1).str[1]\n",
    "df['model'] = df['model'].str.rsplit(pat='_', n=1).str[0]\n",
    "df.head(1)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. Analyze correctness by category"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the correctness by category\n",
    "df_category_correct = df.pivot_table(\"correct\", \"query_category\", aggfunc=\"mean\").sort_values('correct', ascending=False)\n",
    "df_category_correct.plot(kind='barh', color='skyblue', figsize=(10, 6))\n",
    "plt.title('Correctness by SQL category')\n",
    "plt.xlabel('Correctness')\n",
    "plt.ylabel('Category')\n",
    "# add labels\n",
    "for i, v in enumerate(df_category_correct['correct']):\n",
    "    plt.text(v, i, f\"{v*100:.2f}%\", color='black', va='center')\n",
    "ax = plt.gca()\n",
    "for spine in ['right', 'top']:\n",
    "    ax.spines[spine].set_visible(False)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert df to dict\n",
    "category_corr_dict = df_category_correct.sort_values('correct', ascending=True).to_dict()['correct']\n",
    "category_corr_dict"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. Analyze invalid SQL with DB exec errors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get db exec error rows across all result files\n",
    "df_error_exec = df[df['error_db_exec'] == 1][['model', 'db_name', 'question', 'error_db_exec', 'error_msg_short', 'true_queries', 'generated_query']].sort_values(['db_name','question'])\n",
    "# Get questions with recurring exec errors\n",
    "df_error_exec_recurr = df_error_exec[df_error_exec.duplicated(subset=['db_name', 'question'], keep=False)][['question', 'error_db_exec', 'error_msg_short', 'true_queries', 'generated_query']]\n",
    "print(f\"{len(df_error_exec_recurr['question'].unique())} questions with recurring exec errors\")\n",
    "df_error_exec_recurr.head(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert error_msg_short col to a string of bullet points\n",
    "error_exec_str = \"\\n\".join([f\"- {x}\" for x in df_error_exec_recurr['error_msg_short']])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Extract patterns from db execution error messages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get error exec patterns\n",
    "def get_error_exec_patterns(\n",
    "    model: str, error_exec_str: str\n",
    ") -> str:\n",
    "    \"\"\"\n",
    "    Use LLM to extract recurring patterns in a list of error messages.\n",
    "    \"\"\"\n",
    "    messages = [\n",
    "        {\n",
    "            \"role\": \"system\",\n",
    "            \"content\": f\"\"\"Your task is to identify recurring patterns in the error messages below and provide a summary of the patterns.\"\"\"\n",
    "        },\n",
    "        {\n",
    "            \"role\": \"user\",\n",
    "            \"content\": f\"\"\"List of error messages:\n",
    "{error_exec_str}\n",
    "\n",
    "Format your response as a numbered list of recurring patterns in the error messages. \n",
    "Each point should be a concise yet detailed summary of a trend identified in the error messages along with specific examples.\n",
    "Do not include any other information before and after the list.\n",
    "\"\"\",\n",
    "        },\n",
    "    ]\n",
    "\n",
    "    completion = openai.chat.completions.create(\n",
    "        model=model,\n",
    "        messages=messages,\n",
    "        max_tokens=1000,\n",
    "        temperature=0,\n",
    "        # top_p=0.5,\n",
    "        # response_format = {\"type\": \"json_object\"}\n",
    "    )\n",
    "    completion = completion.choices[0].message.content\n",
    "    return completion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "error_exec_summary = get_error_exec_patterns(model_4_latest, error_exec_str)\n",
    "print(error_exec_summary)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4. Analyze valid but wrong examples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get valid but correct examples across all result files\n",
    "df_valid_wrong = df[(df['correct'] == 0) & (df['error_db_exec'] == 0)][['model', 'db_name', 'query_category', 'question', 'instructions', 'correct', 'error_db_exec', 'true_queries', 'generated_query']].sort_values(['db_name','question']).fillna('')\n",
    "# Get questions that were repeatedly valid but wrong\n",
    "df_valid_wrong_recurr = df_valid_wrong[df_valid_wrong.duplicated(subset=['db_name', 'question'], keep=False)][['db_name', 'query_category', 'question', 'instructions', 'true_queries', 'generated_query']]\n",
    "print(f\"{len(df_valid_wrong_recurr['question'].unique())} unique questions that are recurring valid but wrong\")\n",
    "df_valid_wrong_recurr.head(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get first row of all duplicates\n",
    "# To reduce the number of LLM calls, we will assume that all duplicates are wrong in the same way\n",
    "df_valid_wrong_recurr_first = df_valid_wrong_recurr.drop_duplicates(subset=['db_name', 'question'], keep='first')[['query_category', 'question', 'instructions', 'true_queries', 'generated_query']]\n",
    "df_valid_wrong_recurr_first.head(4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get reasons for valid but wrong examples\n",
    "def explain_incorrect(\n",
    "    model: str, question: str, instructions: str, true_sqls: list, generated_sql: str\n",
    ") -> str:\n",
    "    \"\"\"\n",
    "    Use LLM to explain why a SQL query is incorrect given a question, instructions and the true SQL queries.\n",
    "    \"\"\"\n",
    "    if instructions:\n",
    "        instructions = f\"\\nInstructions: {instructions}\\n\"\n",
    "    messages = [\n",
    "        {\n",
    "            \"role\": \"system\",\n",
    "            \"content\": f\"\"\"Your task is to explain why the SQL query is incorrect given the question, instructions and the true SQL queries.\"\"\"\n",
    "        },\n",
    "        {\n",
    "            \"role\": \"user\",\n",
    "            \"content\": f\"\"\"Question: {question}{instructions}\n",
    "Incorrect SQL: {generated_sql}\n",
    "\n",
    "True SQL queries:\n",
    "{true_sqls}\n",
    "\n",
    "Format your response as a valid JSON string with reason as a key. \n",
    "Your response should look like the string below:\n",
    "{{ \"reason\": \"Your reasoning for why the SQL query is incorrect according to the question and the true SQL queries.\"\n",
    "}}\n",
    "\n",
    "Do not include any other information before and after the JSON string.\n",
    "\"\"\",\n",
    "        },\n",
    "    ]\n",
    "\n",
    "    completion = openai.chat.completions.create(\n",
    "        model=model,\n",
    "        messages=messages,\n",
    "        max_tokens=500,\n",
    "        temperature=0,\n",
    "        # top_p=0.5,\n",
    "        response_format = {\"type\": \"json_object\"}\n",
    "    )\n",
    "    completion = completion.choices[0].message.content\n",
    "    try:\n",
    "        completion_dict = json.loads(completion)\n",
    "    except:\n",
    "        print(f\"Error parsing completion {completion}\", flush=True)\n",
    "        completion_dict = {\"reason\": None}\n",
    "    reason = completion_dict.get(\"reason\", None)\n",
    "    return reason"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.notebook import tqdm\n",
    "tqdm.pandas()\n",
    "# Get explanations for valid but wrong examples\n",
    "df_valid_wrong_recurr_first['reason_incorrect'] = df_valid_wrong_recurr_first.progress_apply(lambda x: explain_incorrect(model_4_latest, x['question'], x['instructions'], x['true_queries'], x['generated_query']), axis=1)\n",
    "df_valid_wrong_recurr_first"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert reason_incorrect col to a string of bullet points\n",
    "reason_incorrect_str = \"\\n\".join([f\"- {x}\" for x in df_valid_wrong_recurr_first['reason_incorrect']])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Get recurring patterns from reasons of valid but wrong examples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get reasons for valid but wrong examples\n",
    "def get_valid_wrong_patterns(\n",
    "    model: str, reason_incorrect_str: str\n",
    ") -> str:\n",
    "    \"\"\"\n",
    "    Use LLM to extract recurring patterns in a list of error messages that describe why a SQL query is wrong according to the question and the true SQL queries.\n",
    "    \"\"\"\n",
    "    messages = [\n",
    "        {\n",
    "            \"role\": \"system\",\n",
    "            \"content\": f\"\"\"Your task is to identify recurring patterns in the error messages below and provide a summary of the patterns.\"\"\"\n",
    "        },\n",
    "        {\n",
    "            \"role\": \"user\",\n",
    "            \"content\": f\"\"\"List of error messages that describe why a SQL query is wrong according to the question and the true SQL queries:\n",
    "{reason_incorrect_str}\n",
    "\n",
    "Format your response as a numbered list of recurring patterns in the error messages. \n",
    "Each point should be a concise yet detailed summary of a trend identified in the error messages along with specific examples (e.g. inability to follow instructions, common errors in specific SQL categories, etc.).\n",
    "Do not include any other information before and after the list.\n",
    "\"\"\",\n",
    "        },\n",
    "    ]\n",
    "\n",
    "    completion = openai.chat.completions.create(\n",
    "        model=model,\n",
    "        messages=messages,\n",
    "        max_tokens=1000,\n",
    "        temperature=0,\n",
    "        # top_p=0.5,\n",
    "        # response_format = {\"type\": \"json_object\"}\n",
    "    )\n",
    "    completion = completion.choices[0].message.content\n",
    "    return completion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "valid_wrong_summary = get_valid_wrong_patterns(model_4_latest, reason_incorrect_str)\n",
    "print(valid_wrong_summary)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Store summaries in a json file\n",
    "summary_dict = {\n",
    "    \"category_corr_dict\": category_corr_dict,\n",
    "    \"error_exec_summary\": error_exec_summary,\n",
    "    \"valid_wrong_summary\": valid_wrong_summary\n",
    "}\n",
    "output_file = f\"{results_dir_path}error_analysis_ds_{ds_to_analyze}.json\"\n",
    "with open(output_file, 'w') as f:\n",
    "    json.dump(summary_dict, f, indent=4)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "defog",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
